import os
import numpy as np
import warnings
import random
from glob import glob


import skimage
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
import rasterio
from rasterio import logging
import ast
import json
import pandas as pd

# Disable Pillow's decompression-bomb size check so very large rasters/images can be
# opened, and silence the corresponding warning.
Image.MAX_IMAGE_PIXELS = None
warnings.simplefilter('ignore', Image.DecompressionBombWarning)

# Raise the root logger's threshold to ERROR; loggers without an explicit level of
# their own inherit this, which quiets INFO/WARNING output process-wide.
log = logging.getLogger()
log.setLevel(logging.ERROR)




def get_image_center_coordinates(image_path):
    """
    Read the centre point, affine transform and bounds of a raster file.

    Coordinates are expressed in the raster's own CRS. They are longitude/latitude
    only if that CRS is geographic — NOTE(review): callers assume lon/lat, confirm
    the rasters are stored in a geographic CRS (e.g. EPSG:4326).

    :param image_path: path to a raster readable by ``rasterio.open``.
    :return: tuple ``(lat, lon, transform, lon_min, lat_max, lon_max, lat_min)``:
        ``(lat, lon)`` is the geometric centre of the image, ``transform`` is the
        dataset's affine transform, and the last four values are the dataset bounds
        (left, top, right, bottom).
    """
    with rasterio.open(image_path) as src:
        transform = src.transform
        # Pixel space (0, 0) is the top-left corner of the top-left pixel, so the exact
        # image centre is (width / 2, height / 2). Integer division (the previous
        # behaviour) was half a pixel off for odd widths/heights.
        lon, lat = transform * (src.width / 2, src.height / 2)
        bounds = src.bounds

    lon_min, lat_max = bounds.left, bounds.top
    lon_max, lat_min = bounds.right, bounds.bottom
    return lat, lon, transform, lon_min, lat_max, lon_max, lat_min

class SatelliteDataset(Dataset):
    """
    Abstract base for the datasets in this module.

    Only records the number of input channels (``in_c``); subclasses provide the data.
    """
    def __init__(self, in_c):
        self.in_c = in_c

    @staticmethod
    def build_transform(is_train, input_size, mean, std):
        """
        Build the train or eval transform pipeline for RGB images.

        Both pipelines convert to a tensor and normalize first. Training then applies a
        bicubic random-resized crop (scale 0.2-1.0) and a random horizontal flip. Eval
        resizes so that a centre crop of ``input_size`` keeps the 224/256 ratio used for
        small inputs (no extra margin above 224), then centre-crops.

        :param is_train: True for the training (augmenting) pipeline, False for eval.
        :param input_size: side length of the square output image.
        :param mean: per-channel mean, shape (c,).
        :param std: per-channel std, shape (c,).
        :return: a ``transforms.Compose`` pipeline.
        """
        bicubic = transforms.InterpolationMode.BICUBIC
        steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]

        if is_train:
            steps += [
                transforms.RandomResizedCrop(input_size, scale=(0.2, 1.0), interpolation=bicubic),
                transforms.RandomHorizontalFlip(),
            ]
            return transforms.Compose(steps)

        crop_pct = 224 / 256 if input_size <= 224 else 1.0
        resize_to = int(input_size / crop_pct)
        steps += [
            transforms.Resize(resize_to, interpolation=bicubic),
            transforms.CenterCrop(input_size),
        ]
        return transforms.Compose(steps)


class CustomDatasetFromImages(SatelliteDataset):
    """RGB image-classification dataset driven by a CSV of (label, image path) rows."""
    mean = [0.4182007312774658, 0.4214799106121063, 0.3991275727748871]
    std = [0.28774282336235046, 0.27541765570640564, 0.2764017581939697]

    def __init__(self, csv_path, transform):
        """
        Creates Dataset for regular RGB image classification (usually used for fMoW-RGB dataset).
        :param csv_path: path to a CSV file with a header row; column 0 holds the labels
            and column 1 holds the image paths.
        :param transform: callable mapping an opened PIL image to the model input. It must
            fully materialize the pixels (e.g. via ``ToTensor``), because the image file is
            closed as soon as the transform returns.
        """
        super().__init__(in_c=3)
        self.transforms = transform
        self.data_info = pd.read_csv(csv_path, header=0)
        # Column 1: image paths.
        self.image_arr = np.asarray(self.data_info.iloc[:, 1])
        # Column 0: labels.
        self.label_arr = np.asarray(self.data_info.iloc[:, 0])
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for row ``index`` of the CSV."""
        single_image_name = self.image_arr[index]
        # The context manager releases the file handle even if the transform raises.
        # The transform runs inside the block because PIL decodes pixels lazily.
        with Image.open(single_image_name) as img_as_img:
            img_as_tensor = self.transforms(img_as_img)
        single_image_label = self.label_arr[index]

        return (img_as_tensor, single_image_label)

    def __len__(self):
        """Number of rows in the CSV."""
        return self.data_len

class SentinelNormalize:
    """
    Rescale multi-band Sentinel values to uint8 using a ``mean +/- 2*std`` window.

    Values at ``mean - 2*std`` map to 0 and values at ``mean + 2*std`` map to 255;
    anything outside the window is clipped. Adapted from
    https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/bigearthnet_dataset.py#L111
    """

    def __init__(self, mean, std):
        # Per-channel statistics; numpy broadcasts them against the input's last axis.
        self.mean = np.array(mean)
        self.std = np.array(std)

    def __call__(self, x, *args, **kwargs):
        two_sigma = 2 * self.std
        lower = self.mean - two_sigma
        upper = self.mean + two_sigma
        scaled = (x - lower) / (upper - lower) * 255.0
        return np.clip(scaled, 0, 255).astype(np.uint8)

class ImageNormalizer:
    """Per-channel standardization ``(img - mean) / std`` carried out in float32."""

    def __init__(self, mean, std):
        """
        :param mean: per-channel means, e.g. [mean_r, mean_g, mean_b]
        :param std: per-channel standard deviations, e.g. [std_r, std_g, std_b]
        """
        self.mean, self.std = (np.array(stat, dtype=np.float32) for stat in (mean, std))

    def __call__(self, img, *args, **kwargs):
        """
        :param img: numpy array whose last axis holds the channels, e.g. (h, w, 3).
        :return: float32 array of the same shape, standardized per channel.
        """
        centered = img.astype(np.float32) - self.mean
        return centered / self.std


class SentinelStreetViewPairedImageDataset(SatelliteDataset):
    """
    Pairs one Sentinel satellite patch with a fixed-size set of ``len_sv`` street-view images.

    Each CSV row provides ``city``, ``patch`` and two list-valued columns stored as Python
    literals: ``panoid`` (panorama ids) and ``sv_coordinates`` (one entry per panorama).
    Rows with fewer than ``len_sv`` panoramas are padded; rows with at least ``len_sv``
    are randomly subsampled to exactly ``len_sv``.
    """
    label_types = ['value', 'one-hot']
    sentinel_mean = [1370.19151926, 1184.3824625, 1120.77120066, 1136.26026392,
                     1263.73947144, 1645.40315151, 1846.87040806, 1762.59530783,
                     1972.62420416, 582.72633433, 14.77112979, 1732.16362238, 1247.91870117]
    sentinel_std = [633.15169573, 650.2842772, 712.12507725, 965.23119807,
                    948.9819932, 1108.06650639, 1258.36394548, 1233.1492281,
                    1364.38688993, 472.37967789, 14.3114637, 1310.36996126, 1087.6020813]

    sv_mean = [135.47184384, 142.28621995, 145.12271992]
    sv_std = [59.19978436, 59.35279219, 72.02295734]

    def __init__(self,
                 csv_path,
                 sv_images_path,
                 sat_images_path,
                 sat_transform,
                 sv_transform,
                 repeat=1,
                 len_sv=20,
                 dropped_bands=None
                 ):
        """
        :param csv_path: path to the CSV file with one row per satellite patch.
        :param sv_images_path: root of the street-view images, laid out as
            ``<city>GSV/<panoid>/*.jpg`` plus ``<city>GSV/<panoid>/meta_data.json``.
        :param sat_images_path: root of the satellite images, laid out as
            ``<city>_clip/<time folder>/<patch>.tif``.
        :param sat_transform: callable applied to the (h, w, c) float32 satellite array.
        :param sv_transform: callable applied to each float32 street-view array; its
            output is expected to be a (C, H, W) tensor.
        :param repeat: number of times the CSV rows are repeated (lengthens an epoch).
        :param len_sv: number of street-view images returned per sample.
        :param dropped_bands: optional list of band indices removed from the satellite tensor.
        """
        super().__init__(in_c=13)
        self.df = pd.read_csv(csv_path)
        self.sv_images_path = sv_images_path
        self.sat_images_path = sat_images_path
        self.len_sv = len_sv
        # dropped_bands defaults to None; calling len(None) here used to raise TypeError.
        num_dropped = 0 if dropped_bands is None else len(dropped_bands)
        self.in_c = self.in_c - num_dropped
        self.sat_transform = sat_transform
        self.sv_transform = sv_transform
        self.dropped_bands = dropped_bands
        self.repeat = repeat
        # Repeat the rows and rebuild a contiguous 0..len-1 index for iloc access.
        self.df = pd.concat([self.df] * self.repeat, ignore_index=True)

    def __len__(self):
        return len(self.df)

    def open_image(self, img_path):
        """Read every band of a raster and return it as a (h, w, c) float32 array."""
        with rasterio.open(img_path) as data:
            img = data.read()  # (c, h, w)

        return img.transpose(1, 2, 0).astype(np.float32)  # (h, w, c)

    def process_time_stamp(self, year, month, hour, min_year=2002):
        """Encode a timestamp as ``[year - min_year, month - 1, hour]``."""
        return np.array([year - min_year, month - 1, hour])

    def _load_sv_items(self, city, panoids):
        """
        Load one randomly chosen ``.jpg`` view and the metadata for each panorama id.

        :return: ``(images, locations, times)``. ``images`` is the (k, C, H, W)
            concatenation of the transformed views, ``locations`` is ``[[lon, lat], ...]``
            and ``times`` is ``[[year, month], ...]``.
        """
        images, locations, times = [], [], []
        for panoid in panoids:
            pano_dir = os.path.join(self.sv_images_path, city + "GSV", panoid)
            jpgs = [name for name in os.listdir(pano_dir) if name.endswith('.jpg')]
            sv_path = os.path.join(pano_dir, random.choice(jpgs))
            with open(os.path.join(pano_dir, "meta_data.json"), 'r') as f:
                meta_data = json.load(f)

            locations.append([float(meta_data['lon']), float(meta_data['lat'])])
            times.append([int(meta_data['year']), int(meta_data['month'])])

            sv_image = skimage.io.imread(sv_path).astype(np.float32)  # H,W,C
            images.append(self.sv_transform(sv_image).unsqueeze(0))  # 1,C,H,W
        # torch.cat raises if ``panoids`` is empty (no image shape to pad from).
        return torch.cat(images, dim=0), locations, times

    def __getitem__(self, idx):
        """
        Build the sample for row ``idx`` of the (repeated) dataframe.

        :param idx: row index into ``self.df``.
        :return: dict with ``images`` (satellite tensor, (c, h, w)), ``sv_images``
            (len_sv, C, H, W), ``sv_coordinates``, ``sv_location`` (len_sv, 2) as
            (lon, lat), ``sv_time`` (len_sv, 2) as (year, month), ``timestamps`` (3,),
            ``sat_transform`` (1, 6) and ``bbox_information`` (4,) as
            (lon_min, lat_max, lon_max, lat_min).
        :raises RuntimeError: if every satellite acquisition for the patch is all zeros.
        """
        selection = self.df.iloc[idx]
        panoids = ast.literal_eval(selection["panoid"])
        sv_coordinate = ast.literal_eval(selection["sv_coordinates"])
        patch_number = selection["patch"]
        city = selection["city"]

        # Keep panorama ids and coordinates aligned (zip truncates to the shorter list).
        pairs = list(zip(panoids, sv_coordinate))
        if len(pairs) >= self.len_sv:
            pairs = random.sample(pairs, self.len_sv)
        panoids = [panoid for panoid, _ in pairs]
        sv_coordinate = [coord for _, coord in pairs]

        sv_image_as_tensor, sv_location, sv_time = self._load_sv_items(city, panoids)
        # Build float tensors up front so padding never concatenates mixed dtypes.
        sv_coordinate_as_tensor = torch.tensor(sv_coordinate, dtype=torch.float32)
        sv_location_information_as_tensor = torch.tensor(sv_location, dtype=torch.float32)
        sv_temporal_information_as_tensor = torch.tensor(sv_time, dtype=torch.float32)

        n = self.len_sv - len(pairs)
        if n > 0:
            # Pad up to len_sv: images/location/time with zeros, coordinates with -1.
            _, c, h, w = sv_image_as_tensor.shape
            sv_image_as_tensor = torch.cat([sv_image_as_tensor, torch.zeros((n, c, h, w))], dim=0)
            sv_coordinate_as_tensor = torch.cat(
                [sv_coordinate_as_tensor, torch.full((n,), -1.0)], dim=0)
            sv_location_information_as_tensor = torch.cat(
                [sv_location_information_as_tensor, torch.zeros((n, 2))], dim=0)
            sv_temporal_information_as_tensor = torch.cat(
                [sv_temporal_information_as_tensor, torch.zeros((n, 2))], dim=0)

        sat_dir = os.path.join(self.sat_images_path, f"{city}_clip")
        time_paths = os.listdir(sat_dir)
        # Pick a random acquisition folder, discarding all-zero (empty) rasters.
        while time_paths:
            time_path = random.choice(time_paths)
            image_path = os.path.join(sat_dir, time_path, patch_number + ".tif")
            images = self.open_image(image_path)  # (h, w, c)
            if not np.all(images == 0):
                break
            time_paths.remove(time_path)
        else:
            # Raised explicitly: an IndexError from random.choice([]) would silently end
            # a plain ``for sample in dataset`` loop (old-style __getitem__ iteration).
            raise RuntimeError(
                f"No non-empty satellite image found for patch {patch_number!r} in {sat_dir}")

        # The centre lat/lon are not used here; only the affine transform and bounds are.
        _, _, affine, lon_min, lat_max, lon_max, lat_min = get_image_center_coordinates(image_path)
        # Upper-left (lon_min, lat_max) and bottom-right (lon_max, lat_min) corners.
        bbox_information = torch.tensor([lon_min, lat_max, lon_max, lat_min])
        sat_affine = torch.tensor(list(affine)[:6]).unsqueeze(0)

        img_as_tensor = self.sat_transform(images)  # (c, h, w)

        if self.dropped_bands is not None:
            keep_idxs = [i for i in range(img_as_tensor.shape[0]) if i not in self.dropped_bands]
            img_as_tensor = img_as_tensor[keep_idxs, :, :]

        # Acquisition folder names are parsed as "<year>_<month>_<hour>...".
        parts = time_path.split("_")
        timestamps = torch.from_numpy(
            self.process_time_stamp(int(parts[0]), int(parts[1]), int(parts[2])))

        img_as_tensor = torch.nan_to_num(img_as_tensor, nan=0.0)
        sv_image_as_tensor = torch.nan_to_num(sv_image_as_tensor, nan=0.0)
        sv_coordinate_as_tensor = torch.nan_to_num(sv_coordinate_as_tensor, nan=0.0)
        sv_location_information_as_tensor = torch.nan_to_num(sv_location_information_as_tensor, nan=0.0)

        sample = {
            'images': img_as_tensor.float(),
            "sv_images": sv_image_as_tensor.float(),
            'sv_coordinates': sv_coordinate_as_tensor.float(),
            'sv_location': sv_location_information_as_tensor.float(),
            'sv_time': sv_temporal_information_as_tensor.float(),
            'timestamps': timestamps.float(),
            'sat_transform': sat_affine,
            'bbox_information': bbox_information
        }

        return sample

    @staticmethod
    def build_transform(is_train, input_size, mean, std):
        """
        Satellite pipeline: Sentinel window normalization to uint8, tensor conversion and
        a bicubic resize; training additionally applies a random horizontal flip.
        """
        interpol_mode = transforms.InterpolationMode.BICUBIC
        t = []
        if is_train:
            t.append(SentinelNormalize(mean, std))  # use specific Sentinel normalization to avoid NaN
            t.append(transforms.ToTensor())
            t.append(
                transforms.Resize(input_size, interpolation=interpol_mode),
            )
            t.append(transforms.RandomHorizontalFlip())
            return transforms.Compose(t)

        # eval transform
        t.append(SentinelNormalize(mean, std))
        t.append(transforms.ToTensor())
        t.append(
            transforms.Resize(input_size, interpolation=interpol_mode),
        )

        return transforms.Compose(t)


class SentinelColorJitterTransform:
    def __init__(self, brightness=0.2, contrast=0.2, rgb_indices=(0, 1, 2), probability=0.7):
        """
        Apply ColorJitter to a subset of channels of a (C, H, W) tensor with a given probability.
        :param brightness: brightness jitter factor passed to ``transforms.ColorJitter``
        :param contrast: contrast jitter factor passed to ``transforms.ColorJitter``
        :param rgb_indices: channel indices to jitter. Defaults to (0, 1, 2); pass e.g.
            (1, 2, 3) when the visible bands sit at those positions.
            NOTE(review): ColorJitter's contrast step converts to grayscale assuming the
            channels are in R, G, B order — confirm the band order of the indices passed.
        :param probability: probability of applying the jitter on each call
        """
        self.color_jitter = transforms.ColorJitter(brightness=brightness, contrast=contrast)
        self.rgb_indices = rgb_indices
        self.probability = probability

    def __call__(self, sample):
        """
        Jitter ``sample[rgb_indices]`` with probability ``self.probability``.

        The selected channels are overwritten in place; the same tensor is returned.
        """
        if torch.rand(1).item() < self.probability:
            # An explicit list makes the advanced (gather) indexing unambiguous.
            idx = list(self.rgb_indices)
            sample[idx, :, :] = self.color_jitter(sample[idx, :, :])
        return sample


class SentinelStreetViewPairedImageDataset_for_INF(SatelliteDataset):
    """
    Pairs each street-view JPEG found under ``<root_path>/images`` with a randomly chosen
    satellite GeoTIFF from ``<root_path>/rs_data/<city_id>/``, where ``city_id`` is the
    name of the JPEG's parent folder.

    For v2, ``extend_list`` repeats street-view entries (instead of zero padding) to fill a
    fixed-size list; ``__getitem__`` itself returns a single street-view image per sample.
    """
    label_types = ['value', 'one-hot']
    sentinel_mean = [1172.9397, 1378.0846, 1509.3327,
                     1750.0275, 2073.2758, 2207.3283, 2245.1629,
                     2284.6384, 2231.7283, 1899.1932]
    sentinel_std = [706.6250, 720.1862, 783.1424,
                    707.1962, 714.2782, 748.1827, 852.3585,
                    762.0849, 690.0165, 669.9036]

    sv_mean = [135.47184384, 142.28621995, 145.12271992]
    sv_std = [59.19978436, 59.35279219, 72.02295734]

    def __init__(self,
                 root_path,
                 meta_data_csv_path,
                 sat_transform,
                 sv_transform
                 ):
        """
        :param root_path: folder containing ``images`` (street-view JPEGs, searched
            recursively) and ``rs_data`` (GeoTIFFs grouped by city id).
        :param meta_data_csv_path: currently unused; kept for interface compatibility.
        :param sat_transform: callable applied to the (h, w, c) float32 satellite array
            after bands 0 and 9 are removed.
        :param sv_transform: callable applied to the float32 street-view array.
        """
        # NOTE(review): in_c is 13, but __getitem__ drops bands 0 and 9 and the
        # normalization stats have 10 entries — confirm the intended channel count.
        super().__init__(in_c=13)
        self.main_path = os.path.join(root_path, "images")  # image main path including many subfolders

        self.rs_main_path = os.path.join(root_path, "rs_data")
        self.all_cities = os.listdir(self.main_path)
        self.jpeg_paths = self.get_all_jpeg_paths(self.main_path)
        self.sat_transform = sat_transform
        self.sv_transform = sv_transform

    def get_all_jpeg_paths(self, main_path, samples=1000000):
        """
        Recursively collect ``.jpg``/``.jpeg`` files under ``main_path`` and return a
        random subset of at most ``samples`` of them (all of them, shuffled, if fewer).
        """
        jpeg_paths = []

        for root, dirs, files in os.walk(main_path):
            for file in files:
                if file.lower().endswith(('.jpeg', '.jpg')):
                    jpeg_paths.append(os.path.join(root, file))
        # random.sample raises ValueError when k exceeds the population, so cap k.
        return random.sample(jpeg_paths, min(samples, len(jpeg_paths)))

    def __len__(self):
        return len(self.jpeg_paths)

    def open_image(self, img_path):
        """Read every band of a raster and return it as a (h, w, c) float32 array."""
        with rasterio.open(img_path) as data:
            img = data.read()  # (c, h, w)
        return img.transpose(1, 2, 0).astype(np.float32)  # (h, w, c)

    def process_time_stamp(self, year, month, hour, min_year=2002):
        """Encode a timestamp as ``[year - min_year, month - 1, hour]``."""
        return np.array([year - min_year, month - 1, hour])

    def extend_list(self, original_list, target_length):
        """Pad ``original_list`` with random repeats of its own items up to ``target_length``."""
        extended_list = original_list[:]
        while len(extended_list) < target_length:
            tuple_to_add = random.choice(original_list)
            extended_list.append(tuple_to_add)

        return extended_list

    @staticmethod
    def _read_lon_lat(path):
        """Return the ``(lon, lat)`` values stored in the JSON metadata file at ``path``."""
        with open(path, 'r') as f:
            meta_data_json = json.load(f)
        return meta_data_json['lon'], meta_data_json['lat']

    def __getitem__(self, idx):
        """
        Build the sample for the ``idx``-th street-view JPEG.

        :param idx: index into ``self.jpeg_paths``.
        :return: dict with ``rs_image`` (transformed satellite image), ``sv_image``
            (transformed street-view image), ``sv_location`` as (lon, lat) and
            ``bbox_information`` as (lon_min, lat_max, lon_max, lat_min).
        :raises FileNotFoundError: if no ``.tif`` exists for the JPEG's city id.
        """
        sv_path = self.jpeg_paths[idx]
        sv_uuid = os.path.splitext(os.path.basename(sv_path))[0]
        city_id = os.path.basename(os.path.dirname(sv_path))
        # NOTE(review): these metadata folders are resolved relative to the current
        # working directory, not root_path — confirm this is intended.
        meta_data_path = os.path.join("street_scapes/image_meta_data",
                                      f"{sv_uuid}.json")
        alternative_meta_data_path = os.path.join("data/sv_meta_data", f"{sv_uuid}_meta_data.json")

        rs_candidates = glob(os.path.join(self.rs_main_path, city_id, "*.tif"))
        if not rs_candidates:
            # Explicit error instead of IndexError from random.choice([]), which would
            # silently stop a plain ``for sample in dataset`` loop.
            raise FileNotFoundError(
                f"No .tif files found in {os.path.join(self.rs_main_path, city_id)}")
        rs_path = random.choice(rs_candidates)

        try:
            lon, lat = self._read_lon_lat(meta_data_path)
        except (OSError, ValueError, KeyError):
            # Primary metadata missing, unreadable, malformed or incomplete:
            # fall back to the secondary store.
            lon, lat = self._read_lon_lat(alternative_meta_data_path)

        sv_image = skimage.io.imread(sv_path).astype(np.float32)
        sv_image_as_tensor = self.sv_transform(sv_image)
        _, _, _, lon_min, lat_max, lon_max, lat_min = get_image_center_coordinates(rs_path)
        rs_image = self.open_image(rs_path)
        # Keep bands 1-8 and 10 onwards (bands 0 and 9 are dropped).
        rs_image = np.concatenate((rs_image[:, :, 1:9], rs_image[:, :, 10:]), axis=2)
        rs_image_as_tensor = self.sat_transform(rs_image)

        # Upper-left (lon_min, lat_max) and bottom-right (lon_max, lat_min) corners.
        bbox_information = torch.tensor([lon_min, lat_max, lon_max, lat_min])

        sv_location_information_as_tensor = torch.tensor([lon, lat])
        sample = {
            'rs_image': rs_image_as_tensor.float(),
            "sv_image": sv_image_as_tensor.float(),
            'sv_location': sv_location_information_as_tensor.float(),
            'bbox_information': bbox_information
        }
        # Report and zero out any NaNs (only existing keys are reassigned while iterating).
        for key, value in sample.items():
            if isinstance(value, torch.Tensor) and torch.isnan(value).any():
                print(f"NaN detected in {key}")
                sample[key] = torch.nan_to_num(value, nan=0.0)
        return sample

    @staticmethod
    def build_transform_sat(is_train, input_size, mean, std):
        """
        Satellite pipeline: Sentinel window normalization to uint8, tensor conversion and
        a bicubic resize to a square ``input_size``. Training additionally jitters
        brightness/contrast of channels (1, 2, 3) with probability 0.7.
        """
        interpol_mode = transforms.InterpolationMode.BICUBIC
        t = []

        if is_train:
            t.append(SentinelNormalize(mean, std))  # use specific Sentinel normalization to avoid NaN
            t.append(transforms.ToTensor())
            t.append(
                transforms.Resize((input_size, input_size), interpolation=interpol_mode),
            )
            t.append(SentinelColorJitterTransform(brightness=0.2, contrast=0.2, rgb_indices=(1, 2, 3), probability=0.7))
            return transforms.Compose(t)
        t.append(SentinelNormalize(mean, std))
        t.append(transforms.ToTensor())
        t.append(
            transforms.Resize((input_size, input_size), interpolation=interpol_mode),
        )

        return transforms.Compose(t)

    @staticmethod
    def build_transform_sv(is_train, input_size, mean, std):
        """
        Street-view pipeline: per-channel standardization, tensor conversion and a bicubic
        resize to a square ``input_size``. Training additionally applies a random horizontal
        flip and, with probability 0.3, a mild ColorJitter.
        """
        interpol_mode = transforms.InterpolationMode.BICUBIC
        t = []
        if is_train:
            t.append(ImageNormalizer(mean, std))
            t.append(transforms.ToTensor())
            t.append(
                transforms.Resize((input_size, input_size), interpolation=interpol_mode),
            )
            t.append(transforms.RandomHorizontalFlip())
            t.append(transforms.RandomApply([transforms.ColorJitter(0.1, 0.1, 0.1, 0.1)], p=0.3))
            return transforms.Compose(t)
        t.append(ImageNormalizer(mean, std))
        t.append(transforms.ToTensor())
        t.append(
            transforms.Resize((input_size, input_size), interpolation=interpol_mode),
        )

        return transforms.Compose(t)

def build_fmow_dataset(is_train: bool, args) -> SatelliteDataset:
    """
    Initializes a SatelliteDataset object given provided args.
    :param is_train: Whether we want the dataset for training or evaluation
    :param args: Argparser args object. Always reads ``dataset_type``. ``'rgb'`` also
        reads ``train_path``/``test_path`` and ``input_size``; ``'sentinel_sv_inf'``
        also reads ``input_size`` and ``input_sv_size``.
    :return: SatelliteDataset object.
    :raises ValueError: if ``args.dataset_type`` is not supported.
    """
    if args.dataset_type == 'rgb':
        # Only the CSV-driven RGB dataset needs the train/test split path.
        csv_path = args.train_path if is_train else args.test_path
        mean = CustomDatasetFromImages.mean
        std = CustomDatasetFromImages.std
        transform = CustomDatasetFromImages.build_transform(is_train, args.input_size, mean, std)
        dataset = CustomDatasetFromImages(csv_path, transform)
    elif args.dataset_type == "sentinel_sv_inf":
        inf_cls = SentinelStreetViewPairedImageDataset_for_INF
        sat_transform = inf_cls.build_transform_sat(
            is_train, args.input_size, inf_cls.sentinel_mean, inf_cls.sentinel_std)
        sv_transform = inf_cls.build_transform_sv(
            is_train, args.input_sv_size, inf_cls.sv_mean, inf_cls.sv_std)
        # NOTE(review): data locations are hard-coded relative to the working directory.
        dataset = inf_cls(
            root_path="gsv_rs_project/street_scapes",
            meta_data_csv_path="gsv_rs_project/street_scapes/metadata_common_attributes.csv",
            sat_transform=sat_transform, sv_transform=sv_transform)
    else:
        raise ValueError(f"Invalid dataset type: {args.dataset_type}")
    print(dataset)

    return dataset
